Aalen’s Dynamic Path Model

Causal Inference with Time Varying Effects in PyMC

path-models
sem
causal inference

In this post we fit an Aalen-style additive hazard model with smoothly time-varying coefficients in PyMC, and couple it with a mediator model to estimate time-varying direct, indirect, and total causal effects of a treatment on survival.

import pymc as pm
import numpy as np 
import pandas as pd
import arviz as az
import pytensor.tensor as pt
# Load the simulated start-stop survival data, keeping only the model columns.
model_cols = ['subject', 'x', 'dose', 'M', 'start', 'stop', 'event']
df = pd.read_csv("aalen_simdata.csv")[model_cols]
df.head()
subject x dose M start stop event
0 1 0 ctrl 6.74 0 4.00 0
1 1 0 ctrl 6.91 4 8.00 0
2 1 0 ctrl 6.90 8 12.00 0
3 1 0 ctrl 6.71 12 26.00 0
4 1 0 ctrl 6.45 26 46.85 1
# Compare event rates and mediator levels across treatment / dose groups.
df.groupby(['x', 'dose'])[['event', 'M']].agg(['mean', 'sum'])
event M
mean sum mean sum
x dose
0 ctrl 0.164179 66 6.996915 2812.76
1 high 0.119205 54 8.081589 3660.96
low 0.139037 52 7.302620 2731.18
Code
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import pandas as pd  # NOTE: redundant (already imported above); kept for cell independence

# Derive subject-level info for ordering: treatment arm, then follow-up length.
subject_info = (
    df.groupby('subject')
      .agg(
          x=('x', 'first'),
          max_stop=('stop', 'max')
      )
      .sort_values(['x', 'max_stop'])
)

subjects = subject_info.index.tolist()
subject_to_y = {s: i for i, s in enumerate(subjects)}

fig, ax = plt.subplots(figsize=(8, 0.1 * len(subjects)))

# itertuples() avoids the per-row Series construction cost of iterrows().
for row in df.itertuples(index=False):
    y = subject_to_y[row.subject]

    color = 'tab:blue' if row.x == 1 else 'tab:orange'

    # One horizontal segment per at-risk interval.
    ax.hlines(
        y=y,
        xmin=row.start,
        xmax=row.stop,
        color=color,
        linewidth=3
    )

    # Mark the interval end with a red dot when the event occurred there.
    if row.event == 1:
        ax.plot(
            row.stop,
            y,
            marker='o',
            color='red',
            markersize=6,
            zorder=3
        )

# Axis formatting
ax.set_yticks(range(len(subjects)))
ax.set_yticklabels(subjects)
ax.set_xlabel("Time")
ax.set_ylabel("Subject")

# Visual separation between treatment groups
x0_count = (subject_info['x'] == 0).sum()
ax.axhline(x0_count - 0.5, color='black', linestyle='--', linewidth=1)

# Legend
legend_elements = [
    Line2D([0], [0], color='tab:blue', lw=3, label='x = 1'),
    Line2D([0], [0], color='tab:orange', lw=3, label='x = 0'),
    Line2D([0], [0], marker='o', color='red', lw=0, label='Event', markersize=6)
]

ax.legend(handles=legend_elements, loc='upper right')

ax.set_title("Subject Timelines Ordered by Treatment Level")

plt.tight_layout()
plt.show()

Data Preparation

def prepare_aalen_dpa_data(
    df,
    subject_col="subject",
    start_col="start",
    stop_col="stop",
    event_col="event",
    x_col="x",
    m_col="M",
    dose_col="dose",
):
    """
    Prepare Andersen–Gill / Aalen dynamic path data for PyMC.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format start–stop survival data
    subject_col : str
        Subject identifier
    start_col, stop_col : str
        Interval boundaries
    event_col : str
        Event indicator (0/1)
    x_col : str
        Exposure / treatment
    m_col : str
        Mediator measured at interval start
    dose_col : str
        Dose category column (levels include "low" / "high"; anything
        else — e.g. "ctrl" — is the reference level)

    Returns
    -------
    dict
        Dictionary of numpy arrays ready for PyMC, all aligned to one
        (subject, start)-sorted row order.
    """

    df = df.copy()

    # -------------------------------------------------
    # 1. Basic quantities
    # -------------------------------------------------
    df["dt"] = df[stop_col] - df[start_col]

    if (df["dt"] <= 0).any():
        raise ValueError("Non-positive interval lengths detected.")

    # -------------------------------------------------
    # 2. Time-bin indexing (piecewise-constant effects)
    # -------------------------------------------------
    bins = (
        df[[start_col, stop_col]]
        .drop_duplicates()
        .sort_values([start_col, stop_col])
        .reset_index(drop=True)
    )
    bins["bin_idx"] = np.arange(len(bins))

    df = df.merge(
        bins,
        on=[start_col, stop_col],
        how="left",
        validate="many_to_one"
    )

    n_bins = bins.shape[0]

    # -------------------------------------------------
    # 3. Canonical row order
    # -------------------------------------------------
    # BUG FIX: the original extracted N / bin_idx / x / m BEFORE sorting
    # and dt / m_lag AFTER sorting, so the returned arrays were misaligned
    # whenever the input was not already sorted by (subject, start).
    # Sort once, up front, then extract everything.
    df = df.sort_values([subject_col, start_col]).reset_index(drop=True)

    N = df[event_col].astype(int).values
    Y = np.ones(len(df), dtype=int)  # Andersen–Gill at-risk indicator

    # -------------------------------------------------
    # 4. Center covariates (important for Aalen models)
    # -------------------------------------------------
    # The binary treatment is deliberately left uncentered; only the
    # mediator is centered at its grand mean.
    df["x_c"] = df[x_col]
    df["m_c"] = df[m_col] - df[m_col].mean()

    x = df["x_c"].values

    # -------------------------------------------------
    # 5. Predictable mediator (lag within subject)
    # -------------------------------------------------
    df["m_lag"] = (
        df.groupby(subject_col)["m_c"]
          .shift(1)
          .fillna(0.0)  # no history at a subject's first interval
    )

    m_lag = df["m_lag"].values
    bin_idx = df["bin_idx"].values

    # Dose dummies relative to the reference (control) level.
    df["I_low"]  = (df[dose_col] == "low").astype(int)
    df["I_high"] = (df[dose_col] == "high").astype(int)

    # -------------------------------------------------
    # 6. Assemble output
    # -------------------------------------------------
    data = {
        "N": N,
        "Y": Y,
        "dt": df["dt"].values,
        "bin_idx": bin_idx,
        "x": x,
        "m_lag": m_lag,
        "n_bins": n_bins,
        "bins": bins,     # useful for plotting
        "df_long": df     # optional: debugging / inspection
    }

    return data
# Build the model-input arrays from the long-format data.
data = prepare_aalen_dpa_data(df)
df_long = data['df_long']
# Inspect the first subjects' intervals with their derived bin indices.
df_long[['subject', 'x', 'dose', 'M', 'event', 'dt', 'bin_idx']].head(14)
subject x dose M event dt bin_idx
0 1 0 ctrl 6.74 0 4.00 7
1 1 0 ctrl 6.91 0 4.00 13
2 1 0 ctrl 6.90 0 4.00 23
3 1 0 ctrl 6.71 0 14.00 53
4 1 0 ctrl 6.45 1 20.85 81
5 2 1 high 6.11 0 4.00 7
6 2 1 high 6.28 0 4.00 13
7 2 1 high 7.04 0 4.00 23
8 2 1 high 6.93 0 14.00 53
9 2 1 high 7.86 0 26.00 89
10 2 1 high 8.47 0 26.00 115
11 2 1 high 8.91 0 26.00 137
12 2 1 high 8.99 0 52.00 162
13 2 1 high 9.36 0 104.00 188
# Unpack the prepared model inputs into local names.
(N, Y, dt, bin_idx, x, m_lag, n_bins, bins, df_long) = (
    data[key]
    for key in ("N", "Y", "dt", "bin_idx", "x", "m_lag",
                "n_bins", "bins", "df_long")
)

# Bin-level interval lengths plus shorthand aliases used in the model below.
dt_bins = bins["stop"].values - bins["start"].values
m = df_long["m_c"].values
b = bin_idx
from scipy.interpolate import BSpline

def create_bspline_basis(n_bins, n_knots=10, degree=3):
    """
    Build a clamped B-spline design matrix over the time-bin index.

    Parameters
    ----------
    n_bins : int
        Number of time bins; basis functions are evaluated at 0, ..., n_bins-1.
    n_knots : int
        Number of equally spaced knots spanning the time range
        (fewer knots -> smoother functions).
    degree : int
        Spline degree (3 gives cubic splines).

    Returns
    -------
    np.ndarray
        Array of shape (n_bins, n_knots + degree - 1) whose columns are
        the basis functions evaluated at each time bin.
    """
    # Equally spaced knots across [0, n_bins - 1]; repeating each end knot
    # `degree` extra times clamps the spline at the boundaries.
    grid = np.linspace(0, n_bins - 1, n_knots)
    knots = np.concatenate([
        np.repeat(grid[0], degree),
        grid,
        np.repeat(grid[-1], degree)
    ])

    # Standard B-spline count for this knot vector and degree.
    n_basis = len(knots) - degree - 1
    eval_points = np.arange(n_bins, dtype=float)

    # Column i is the spline whose coefficient vector is the i-th standard
    # basis vector, i.e. the i-th B-spline basis function itself.
    identity = np.eye(n_basis)
    columns = [
        BSpline(knots, identity[i], degree, extrapolate=False)(eval_points)
        for i in range(n_basis)
    ]
    return np.column_stack(columns)

# Evaluate the cubic spline basis on the time-bin grid.
n_knots = 10
basis = create_bspline_basis(n_bins, n_knots=n_knots, degree=3)
n_cols = basis.shape[1]

# Label each basis-function column for inspection.
feature_names = [f'feature_{i}' for i in range(n_cols)]
basis_df = pd.DataFrame(basis, columns=feature_names)
basis_df.head(10)
feature_0 feature_1 feature_2 feature_3 feature_4 feature_5 feature_6 feature_7 feature_8 feature_9 feature_10 feature_11
0 1.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1 0.863149 0.133496 0.003337 0.000018 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2 0.739389 0.247518 0.012946 0.000146 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
3 0.628064 0.343219 0.028223 0.000494 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
4 0.528515 0.421749 0.048566 0.001170 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
5 0.440083 0.484261 0.073370 0.002286 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
6 0.362110 0.531908 0.102032 0.003950 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
7 0.293939 0.565840 0.133949 0.006272 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
8 0.234909 0.587211 0.168518 0.009362 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
9 0.184365 0.597171 0.205134 0.013330 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
n_basis = basis.shape[1]
n_obs = data['df_long'].shape[0]
time_bins = data['bins']['bin_idx'].values

# Observed arrays, all aligned to the row order of df_long.
observed_mediator = df_long["m_c"].values
observed_events = df_long['event'].astype(int).values
observed_treatment = df_long['x'].astype(int).values
observed_mediator_lag = df_long['m_lag'].values


# Named dimensions for the PyMC model.
coords = {
    'tv': ['intercept', 'direct', 'mediator'],
    # BUG FIX: the original used 'spline_f_{i}' without the f-prefix, which
    # produced n_basis identical labels; duplicate coordinate values break
    # per-spline selection in the posterior.
    'splines': [f'spline_f_{i}' for i in range(n_basis)],
    'obs': range(n_obs),
    'time_bins': time_bins,
}

with pm.Model(coords=coords) as aalen_dpa_model:

    # ---- Observed-data containers (pm.Data, so values can be swapped) ----
    trt = pm.Data("trt", observed_treatment, dims="obs")                  # treatment (0/1)
    med = pm.Data("mediator", observed_mediator, dims="obs")              # centered mediator
    med_lag = pm.Data("mediator_lag", observed_mediator_lag, dims="obs")  # within-subject lag
    events = pm.Data("events", observed_events, dims="obs")               # event indicator
    I_low  = pm.Data("I_low",  df_long["I_low"].values,  dims="obs")      # dose == "low" dummy
    I_high = pm.Data("I_high", df_long["I_high"].values, dims="obs")      # dose == "high" dummy
    dt = pm.Data("duration", df_long['dt'].values, dims='obs')            # interval length (exposure)
    ## because our long data format has a cell per obs
    at_risk = pm.Data("at_risk", np.ones(len(observed_events)), dims="obs")
    basis_ = pm.Data("basis", basis_df.values, dims=('time_bins', 'splines') )

    # -------------------------------------------------
    # 1. B-spline coefficients for HAZARD model
    # -------------------------------------------------
    # Prior on spline coefficients
    # Smaller sigma = less wiggliness
    # Random Walk 1 (RW1) Prior for coefficients
    # This is the Bayesian version of the smoothing penalty in R's 'mgcv' or 'timereg'
    # One smoothness scale per time-varying function (intercept / direct / mediator).
    sigma_smooth = pm.Exponential("sigma_smooth", [1, 1, 1], dims='tv')
    beta_raw = pm.Normal("beta_raw", 0, 1, dims=('splines', 'tv'))

    # Cumulative sum makes it a Random Walk
    # This ensures coefficients evolve smoothly over time
    # (non-centered parameterization: scale the N(0,1) increments, then cumsum)
    coef_alpha = pm.Deterministic("coef_alpha", pt.cumsum(beta_raw * sigma_smooth, axis=0), dims=('splines', 'tv'))

    # Construct smooth time-varying functions
    # Each is a length-n_bins curve: basis (time_bins x splines) @ coefficients.
    alpha_0_t = pt.dot(basis_, coef_alpha[:, 0])
    alpha_1_t = pt.dot(basis_, coef_alpha[:, 1])
    alpha_2_t = pt.dot(basis_, coef_alpha[:, 2])

    # -------------------------------------------------
    # 2. B-spline coefficients for MEDIATOR model
    # -------------------------------------------------
    sigma_beta_smooth = pm.Exponential("sigma_beta_smooth", 0.1)
    # NOTE(review): this rebinding shadows the hazard-model `beta_raw` above.
    # Harmless here because that tensor was already consumed in coef_alpha,
    # but a distinct Python name would be clearer.
    beta_raw = pm.Normal("beta_raw_m", 0, 1, dims=('splines'))
    coef_beta = pt.cumsum(beta_raw * sigma_beta_smooth)

    beta_t = pt.dot(basis_, coef_beta)

    # -------------------------------------------------
    # 3. Mediator model (A path: x → M)
    # -------------------------------------------------
    sigma_m = pm.HalfNormal("sigma_m", 1.0)

    # Autoregressive component
    rho = pm.Beta("rho", 2, 2)

    # Time-varying treatment effect on the mediator plus AR(1)-style carryover.
    # (No intercept: the mediator is centered upstream.)
    mu_m = beta_t[b] * trt + rho * med_lag

    pm.Normal(
        "obs_m",
        mu=mu_m,
        sigma=sigma_m,
        observed=med,
        dims='obs'
    )

    # -------------------------------------------------
    # 4. Hazard model (direct + B path)
    # -------------------------------------------------
    # Time-constant dose effects relative to the control arm.
    beta_low  = pm.Normal("beta_low",  0, 0.1)
    beta_high = pm.Normal("beta_high", 0, 0.1)
    # Log-additive hazard
    # NOTE(review): despite the name, this is the pre-link linear predictor;
    # the hazard itself is softplus(log_lambda_t) below, which keeps it >= 0.
    log_lambda_t = (alpha_0_t[b] 
                    + alpha_1_t[b] * trt # direct effect
                    + alpha_2_t[b] * med  # mediator effect
                    + beta_low  * I_low
                    + beta_high * I_high
    )

    # Expected number of events
    # Poisson-trick likelihood for piecewise-constant hazards:
    # E[events] = exposure * hazard, with exposure = at_risk * dt.
    time_at_risk = at_risk * dt
    Lambda = time_at_risk * pm.math.log1pexp(log_lambda_t)

    pm.Poisson(
        "obs_event",
        mu=Lambda,
        observed=events, 
        dims='obs'
    )

    # -------------------------------------------------
    # 5. Causal path effects
    # -------------------------------------------------
    # Store time-varying coefficients
    pm.Deterministic("alpha_0_t", alpha_0_t, dims='time_bins')
    pm.Deterministic("alpha_1_t", alpha_1_t, dims='time_bins')  # direct effect
    pm.Deterministic("alpha_2_t", alpha_2_t, dims='time_bins')  # B path
    pm.Deterministic("beta_t", beta_t, dims='time_bins')        # A path

    # Cumulative direct effect
    cum_de = pm.Deterministic(
        "tv_direct_effect",
        alpha_1_t, 
        dims='time_bins'
    )

    # Cumulative indirect effect (product of paths)
    # Classic product-of-coefficients mediation, evaluated per time bin.
    cum_ie = pm.Deterministic(
        "tv_indirect_effect",
        beta_t * alpha_2_t, 
        dims='time_bins'
    )

    # Total effect
    cum_te = pm.Deterministic(
        "tv_total_effect",
        cum_de + cum_ie,
        dims='time_bins'
    )

    # -------------------------------------------------
    # 6. Sample
    # -------------------------------------------------
    trace_additive = pm.sample(
        draws=2000,
        tune=2000,
        target_accept=0.95,
        chains=4,
        nuts_sampler="numpyro",
        random_seed=42,
        init="adapt_diag"
    )
There were 12 divergences after tuning. Increase `target_accept` or reparameterize.
# Render the model graph (plate notation) for visual inspection.
pm.model_to_graphviz(aalen_dpa_model)

vars_to_plot = ['tv_direct_effect', 'tv_indirect_effect', 'tv_total_effect']
labels = ['Time varying Direct Effect', 'Time varying Indirect Effect', 'Time varying Total Effect']

fig, axs = plt.subplots(1, 3, figsize=(20, 6))

for i, var in enumerate(vars_to_plot):
    # 1. Posterior for this variable with explicit (chain, draw, time) dims.
    post = trace_additive.posterior[var]

    # 2. Posterior mean over chains and draws.
    mean_val = post.mean(dim=("chain", "draw")).values

    # BUG FIX: the original passed a 2d (sample, time) numpy array to az.hdi,
    # which triggers the FutureWarning about (draw, shape) vs (chain, draw)
    # interpretation. Passing the InferenceData and selecting the variable
    # keeps the sample dimensions unambiguous. Returns a (time, 2) array.
    hdi_val = az.hdi(trace_additive, var_names=[var], hdi_prob=0.94)[var].values

    # 3. Plot the Mean line
    x_axis = np.arange(len(mean_val))
    axs[i].plot(x_axis, mean_val, label=labels[i], color='teal', lw=2)

    # 4. Plot the Shaded HDI region
    axs[i].fill_between(x_axis, hdi_val[:, 0], hdi_val[:, 1], color='teal', alpha=0.2, label='94% HDI')

    # Formatting
    axs[i].set_title(labels[i])
    axs[i].legend()
    axs[i].grid(alpha=0.3)
    # If you have a time vector (e.g., days), replace x_axis with your time values

plt.tight_layout()
plt.show()
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_28758/127617944.py:13: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_28758/127617944.py:13: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array
/var/folders/__/ng_3_9pn1f11ftyml_qr69vh0000gn/T/ipykernel_28758/127617944.py:13: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  hdi_val = az.hdi(post_samples, hdi_prob=0.94) # Returns [time, 2] array

Citation

BibTeX citation:
@online{forde,
  author = {Forde, Nathaniel},
  title = {Aalen’s {Dynamic} {Path} {Model}},
  langid = {en}
}
For attribution, please cite this work as:
Forde, Nathaniel. n.d. “Aalen’s Dynamic Path Model.”